# Containers keyed by simulation scenario name; filled in further down.
data_ls <- list()
vimp_ls <- list()
fit_ls <- list()

# Simulation settings shared by every scenario below.
n <- 500        # number of observations
p <- 10         # number of baseline features (V1, ..., Vp)
beta <- 1       # coefficient multiplying each signal feature
noise_sd <- 1   # standard deviation of the additive Gaussian noise
# Case 1: independent features -------------------------------------------
# All p features are i.i.d. standard normal; the response depends only on
# the first two (V1 and V2), each weighted by beta, plus Gaussian noise.
X <- matrix(rnorm(n * p), nrow = n, ncol = p) |>
  as.data.frame()
y <- X[, 1] * beta +
  X[, 2] * beta +
  rnorm(n, sd = noise_sd)
data_ls[["Independent Features"]] <- data.frame(y = y, X)
# Case 2: correlated non-signal features ----------------------------------
# Append p_cor noise features that are equicorrelated with one another
# (pairwise correlation rho, unit variance) but unrelated to y. They are
# prefixed "corr" so the plots can colour them separately.
p_cor <- 100
rho <- 0.9
sig <- matrix(rho, p_cor, p_cor)
# Unit variances: without this the covariance is the rank-1 matrix
# rho * 11^T, and mvrnorm() returns p_cor (numerically) identical columns,
# i.e. correlation 1 rather than rho.
diag(sig) <- 1
X_cor <- MASS::mvrnorm(n, mu = rep(0, p_cor), Sigma = sig)
data_ls[["Correlated Non-signal Features"]] <- dplyr::bind_cols(
  data.frame(y = y, X),
  as.data.frame(X_cor) |> dplyr::rename_with(~ paste0("corr", .x))
)
# Case 3: correlated signal features (one signal) -------------------------
# Draw p_cor + 1 equicorrelated features (pairwise correlation rho, unit
# variance). The first replaces signal feature V1; the remaining p_cor become
# the non-signal "corr" features, so each is correlated rho with V1.
sig <- matrix(rho, p_cor + 1, p_cor + 1)
# Unit variances: without this the covariance is the rank-1 matrix
# rho * 11^T, and mvrnorm() returns (numerically) identical columns, making
# V1 perfectly collinear with every corr feature.
diag(sig) <- 1
X_cor <- MASS::mvrnorm(n, mu = rep(0, p_cor + 1), Sigma = sig)
X[, 1] <- X_cor[, 1]
X_cor <- X_cor[, -1]
y <- X[, 1] * beta + X[, 2] * beta + rnorm(n, sd = noise_sd)
data_ls[["Correlated Signal Features (1 signal)"]] <- dplyr::bind_cols(
  data.frame(y = y, X),
  as.data.frame(X_cor) |> dplyr::rename_with(~ paste0("corr", .x))
)
# Case 4: correlated signal features (many signal) ------------------------
# Reuses X and X_cor from case 3, but the response now also depends on the
# first four correlated features (corrV1-corrV4), each weighted by beta.
y <- X[, 1] * beta +
  X[, 2] * beta +
  rowSums(X_cor[, 1:4] * beta) +
  rnorm(n, sd = noise_sd)
data_ls[["Correlated Signal Features (many signal)"]] <- dplyr::bind_cols(
  data.frame(y = y, X),
  dplyr::rename_with(as.data.frame(X_cor), ~ paste0("corr", .x))
)
# Variable-importance comparison -------------------------------------------
# For each simulated dataset in data_ls:
#   1. fit linear, LASSO, ridge and several random-forest importance methods;
#   2. render a faceted bar chart of the importances as an R Markdown tab;
#   3. store the fitted models in fit_ls and the importances in vimp_ls.

# NOTE(review): subchunk_idx is not defined in this script; it is presumably
# set in an earlier chunk. Initialize it only when absent, so an existing
# counter (and therefore unique subchunk labels) is preserved.
if (!exists("subchunk_idx")) {
  subchunk_idx <- 1
}

# Prediction wrapper for fastshap: returns the `predictions` element of
# ranger's predict() output.
pred_fun <- function(object, newdata) {
  predict(object, newdata)$predictions
}

# Extract the lambda.min coefficients of a cv.glmnet fit, minus the
# intercept, as a (var, vimp) tibble.
glmnet_vimp <- function(fit) {
  coefs <- as.matrix(coef(fit, s = "lambda.min"))
  tibble::tibble(
    var = rownames(coefs)[-1],
    vimp = coefs[-1]
  )
}

for (sim_name in names(data_ls)) {
  cat(sprintf(
    "\n\n## %s {.tabset .tabset-pills .tabset-square}\n\n", sim_name
  ))
  data <- data_ls[[sim_name]]
  y <- data |>
    dplyr::pull(y)
  X <- data |>
    dplyr::select(-y)

  # linear regression: importance = estimated coefficient (with its SE).
  # summary() is computed once; the intercept row is dropped.
  lm_fit <- lm(y ~ ., data = data)
  lm_coefs <- summary(lm_fit)$coefficients[-1, , drop = FALSE]
  lm_vimp_df <- tibble::tibble(
    var = rownames(lm_coefs),
    vimp = lm_coefs[, 1],
    se = lm_coefs[, 2]
  )

  # LASSO regression
  lasso_fit <- glmnet::cv.glmnet(
    x = as.matrix(X),
    y = y,
    alpha = 1,
    nfolds = 5
  )
  lasso_vimp_df <- glmnet_vimp(lasso_fit)

  # ridge regression
  ridge_fit <- glmnet::cv.glmnet(
    x = as.matrix(X),
    y = y,
    alpha = 0,
    nfolds = 5
  )
  ridge_vimp_df <- glmnet_vimp(ridge_fit)

  # random forest (MDI)
  rf_fit <- ranger::ranger(
    data = data,
    formula = y ~ .,
    importance = "impurity"
  )
  rf_vimp_mdi_df <- tibble::tibble(
    var = names(rf_fit$variable.importance),
    vimp = rf_fit$variable.importance
  )

  # random forest (permutation)
  rf_fit <- ranger::ranger(
    data = data,
    formula = y ~ .,
    importance = "permutation"
  )
  rf_vimp_perm_df <- tibble::tibble(
    var = names(rf_fit$variable.importance),
    vimp = rf_fit$variable.importance
  )

  # random forest (feature occlusion / LOCO): refit without each feature and
  # record how much the OOB error rises relative to the permutation fit above.
  # Using out-of-bag error as the metric for simplicity (should generally use
  # a held-out test set).
  rf_vars <- names(rf_fit$variable.importance)
  oob_errs <- vapply(
    rf_vars,
    function(j) {
      X_loco_j <- X |>
        dplyr::select(-tidyselect::all_of(j))
      rf_fit_j <- ranger::ranger(
        data = cbind(y = y, X_loco_j),
        formula = y ~ .
      )
      rf_fit_j$prediction.error
    },
    numeric(1)
  )
  rf_vimp_loco_df <- tibble::tibble(
    var = rf_vars,
    vimp = oob_errs - rf_fit$prediction.error
  )

  # random forest (shap): importance = mean absolute SHAP value per feature
  rf_fit <- ranger::ranger(
    data = data,
    formula = y ~ .
  )
  shap_values <- fastshap::explain(
    object = rf_fit,
    X = X,
    pred_wrapper = pred_fun,
    nsim = 10
  )
  rf_vimp_shap_df <- as.data.frame(abs(shap_values)) |>
    dplyr::summarise(
      dplyr::across(
        tidyselect::everything(),
        ~ mean(.x)
      )
    ) |>
    tidyr::pivot_longer(
      cols = tidyselect::everything(),
      names_to = "var",
      values_to = "vimp"
    )

  # Stack all methods into one long data frame, colouring the two baseline
  # signal features and the "corr"-prefixed features distinctly.
  vimp_df <- list(
    Linear = lm_vimp_df,
    LASSO = lasso_vimp_df,
    Ridge = ridge_vimp_df,
    `RF (MDI)` = rf_vimp_mdi_df,
    `RF (permute)` = rf_vimp_perm_df,
    `RF (SHAP)` = rf_vimp_shap_df,
    `RF (LOCO)` = rf_vimp_loco_df
  ) |>
    dplyr::bind_rows(.id = "method") |>
    dplyr::mutate(
      method = forcats::fct_inorder(method),
      color = dplyr::case_when(
        var == "V1" ~ "Signal1",
        var == "V2" ~ "Signal2",
        stringr::str_detect(var, "corr") ~ "Correlated",
        TRUE ~ "Other"
      ),
      var = forcats::fct_inorder(var)
    )

  plt <- vimp_df |>
    ggplot2::ggplot() +
    ggplot2::aes(x = var, y = vimp, fill = color) +
    ggplot2::geom_bar(stat = "identity") +
    ggplot2::facet_wrap(~ method, ncol = 1, scales = "free_y") +
    ggplot2::labs(x = "Feature", y = "Importance", fill = "") +
    ggplot2::scale_fill_manual(
      values = c(
        Signal1 = "dodgerblue",
        Signal2 = "darkgreen",
        Correlated = "orange",
        Other = "gray"
      )
    ) +
    vthemes::theme_vmodern(
      x_text_angle = TRUE,
      size_preset = "large"
    )

  # Wider figure when there are many features to label on the x-axis.
  fig_width <- if (ncol(data) > 20) 20 else 10
  vthemes::subchunkify(
    plt, i = subchunk_idx, fig_height = 16, fig_width = fig_width
  )
  subchunk_idx <- subchunk_idx + 1

  # `RF` holds the last forest fit (the one used for SHAP).
  fit_ls[[sim_name]] <- list(
    Linear = lm_fit,
    LASSO = lasso_fit,
    Ridge = ridge_fit,
    `RF` = rf_fit
  )
  vimp_ls[[sim_name]] <- vimp_df
}